package com.whoyx.jiebing.utils.nlp.word_encoder;

import ai.djl.util.Utils;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/**
 * Looks up embedding vectors by word.
 *
 * <p>The vocabulary file is read as a list of lines and the embedding file is read as a
 * NumPy {@code .npy} matrix; the word on line {@code i} is paired with row {@code i} of the
 * matrix. No check is made that both files contain the same number of entries.
 *
 * @author Stone
 */
public class WordEncoder {

    /** Vocabulary lines in file order; entry {@code i} owns row {@code i} of {@link #embeddings}. */
    List<String> words;

    /** Embedding matrix as loaded from the {@code .npy} file. */
    float[][] embeddings;

    /**
     * Maps each word to the index of its first occurrence in {@link #words}.
     * Built once in the constructor; it is not refreshed if {@link #words} is modified later.
     */
    private final Map<String, Integer> index;

    /**
     * Loads the vocabulary and the embedding matrix from disk.
     *
     * @param vocab     path of the vocabulary text file
     * @param embedding path of the {@code .npy} file holding the embedding matrix
     * @throws IOException if the vocabulary file cannot be opened or closed
     */
    public WordEncoder(Path vocab, Path embedding) throws IOException {
        // Open the Path directly: going through String/File and back would silently
        // rebind a Path from a non-default file system to the default one.
        try (InputStream is = Files.newInputStream(vocab)) {
            // NOTE(review): the 'false' flag is presumably DJL's "trim" switch,
            // i.e. lines are kept verbatim -- confirm against the DJL version in use.
            words = Utils.readLines(is, false);
        }

        try (INDArray array = Nd4j.readNpy(embedding.toFile())) {
            embeddings = array.toFloatMatrix();
        }

        // Index the vocabulary once so that search() does not rescan the list per call.
        // putIfAbsent keeps the first occurrence of a duplicated word, matching what a
        // front-to-back scan of the list would find.
        index = new HashMap<>();
        for (int i = 0; i < words.size(); i++) {
            index.putIfAbsent(words.get(i), i);
        }
    }

    /**
     * Returns the embedding row associated with {@code word}.
     *
     * <p>If the word occurs more than once in the vocabulary, the row of its first
     * occurrence is returned. The returned array is the internal row, not a copy.
     *
     * @param word the word to look up; {@code null} is treated as not found
     * @return the embedding row for {@code word}, or {@code null} if the word is not in
     *         the vocabulary
     * @throws ArrayIndexOutOfBoundsException if the word's vocabulary position has no
     *         corresponding row in the embedding matrix
     */
    public float[] search(String word) {
        Integer row = index.get(word);
        return row == null ? null : embeddings[row];
    }

}
